{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sandbox notebook: intended for others to experiment with the code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from cont_speech_experiment import ContinuousSpeechExperiment, ClasswiseDataset\n",
    "from nupic.research.support import parse_config\n",
    "\n",
    "from nupic.research.frameworks.continuous_learning.utils import clear_labels, freeze_output_layer\n",
    "from exp_lesparse import LeSparseNet\n",
    "\n",
    "import os\n",
    "\n",
    "from nupic.research.frameworks.pytorch.model_utils import evaluate_model\n",
    "from nupic.research.frameworks.continuous_learning.dendrite_layers import (\n",
    "    DendriteLayer, DendriteInput, DendriteOutput\n",
    ")\n",
    "from nupic.torch.modules import (\n",
    "    Flatten,\n",
    "    KWinners,\n",
    "    KWinners2d,\n",
    "    SparseWeights,\n",
    "    SparseWeights2d,\n",
    ")\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating optimizer with learning rate= 0.01\n"
     ]
    }
   ],
   "source": [
    "# Parse the experiment definitions from the config file.\n",
    "config_file = \"experiments.cfg\"\n",
    "with open(config_file) as cf:\n",
    "    config_init = parse_config(cf)\n",
    "    \n",
    "# Name of the config_init entry used as the base configuration.\n",
    "exp = \"sparseCNN2\"\n",
    "\n",
    "config = config_init[exp]\n",
    "# Settings written into the parsed configuration for this notebook.\n",
    "config[\"name\"] = exp\n",
    "config[\"use_dendrites\"] = True\n",
    "config[\"use_batch_norm\"] = False\n",
    "config[\"cnn_out_channels\"] = (64, 64)\n",
    "config[\"cnn_percent_on\"] = (0.12, 0.07)\n",
    "config[\"cnn_weight_sparsity\"] = (0.15, 0.05)\n",
    "config[\"dendrites_per_cell\"] = 2\n",
    "config[\"batch_size\"] = 64\n",
    "# The experiment object is used by later cells for combine_classes,\n",
    "# train_loader and test_loader.\n",
    "experiment = ContinuousSpeechExperiment(config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_no_params(model):\n",
    "    \"\"\"Count the trainable parameters of a model.\n",
    "\n",
    "    Args:\n",
    "        model: object whose ``parameters()`` yields tensors, e.g. an ``nn.Module``.\n",
    "\n",
    "    Returns:\n",
    "        int: total number of elements over all parameters with ``requires_grad`` set.\n",
    "    \"\"\"\n",
    "    # numel() gives an exact Python int per tensor; np.prod(p.size()) returned\n",
    "    # NumPy scalars (and the float 1.0 for 0-dim parameters).\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "def clear_labels(labels, n_classes=5):\n",
    "    \"\"\"Return ``np.arange(n_classes)`` with the positions listed in ``labels`` removed.\n",
    "\n",
    "    NOTE(review): this definition shadows the ``clear_labels`` imported from\n",
    "    ``nupic.research.frameworks.continuous_learning.utils`` in the imports cell.\n",
    "    \"\"\"\n",
    "    return np.delete(np.arange(n_classes), labels)\n",
    "\n",
    "def reshape_xs(data, target, device=torch.device(\"cuda\"), non_blocking=None):\n",
    "    \"\"\"Flatten every sample to 32*32 values and move the batch to ``device``.\n",
    "\n",
    "    Meant for use with ``evaluate_model`` when the images are flattened.\n",
    "    ``non_blocking`` is accepted but not used.\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(data, target)``, both on ``device``.\n",
    "    \"\"\"\n",
    "    flat = data.reshape(data.shape[0], 32 * 32)\n",
    "    return flat.to(device), target.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ToyNetwork(nn.Module):\n",
    "    \"\"\"Sparse conv layer followed by a dendrite layer and a sparse linear output.\n",
    "\n",
    "    forward() applies: conv1, optional batch norm, max pool, k-winners, flatten,\n",
    "    the dendrite layer ``d1``, the sparse linear layer and a log-softmax over\n",
    "    ``n_classes + 1`` outputs.\n",
    "\n",
    "    Args:\n",
    "        dpc: dendrites per neuron of ``d1``.\n",
    "        cnn_w_sparsity: second argument of the ``SparseWeights2d`` wrapping the conv layer.\n",
    "        linear_w_sparsity: second argument of the ``SparseWeights`` wrapping the output layer.\n",
    "        cat_w_sparsity: second argument of the ``SparseWeights`` wrapping the categorical\n",
    "            projection; only used when ``do_cat`` is true.\n",
    "        n_classes: the output layer has ``n_classes + 1`` units.\n",
    "        do_cat: if true, create ``self.cat`` so that ``forward`` can pass a projection\n",
    "            of the one-hot label to ``d1``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dpc=3,\n",
    "                 cnn_w_sparsity=0.05,\n",
    "                 linear_w_sparsity=0.5,\n",
    "                 cat_w_sparsity=0.01,\n",
    "                 n_classes=4,\n",
    "                 do_cat=False):\n",
    "        super().__init__()\n",
    "        conv_channels = 128\n",
    "        self.n_classes = n_classes\n",
    "        self.do_cat = do_cat\n",
    "        self.conv1 = SparseWeights2d(\n",
    "            nn.Conv2d(in_channels=1,\n",
    "                      out_channels=conv_channels,\n",
    "                      kernel_size=10,\n",
    "                      padding=0,\n",
    "                      stride=1),\n",
    "            cnn_w_sparsity)\n",
    "        self.kwin1 = KWinners2d(conv_channels, percent_on=0.1)\n",
    "        self.bn = nn.BatchNorm2d(conv_channels, affine=False)\n",
    "        self.mp1 = nn.MaxPool2d(kernel_size=2)\n",
    "        self.flatten = Flatten()\n",
    "\n",
    "        # Flattened feature size: 128 * 11 * 11 = 15488 (same value as before).\n",
    "        # NOTE(review): 11x11 is what a 32x32 input yields after the 10x10 conv\n",
    "        # and the 2x2 max pool - confirm the input size.\n",
    "        self.d1 = DendriteLayer(in_dim=conv_channels * 11 * 11,\n",
    "                                out_dim=1000,\n",
    "                                dendrites_per_neuron=dpc)\n",
    "\n",
    "        self.linear = SparseWeights(nn.Linear(1000, n_classes + 1), linear_w_sparsity)\n",
    "\n",
    "        if self.do_cat:\n",
    "            self.cat = SparseWeights(nn.Linear(n_classes + 1, 1000 * dpc), cat_w_sparsity)\n",
    "\n",
    "    def forward(self, x, label=None, batch_norm=False):\n",
    "        \"\"\"Return log-probabilities for the batch ``x``.\n",
    "\n",
    "        Args:\n",
    "            x: input batch for ``conv1``.\n",
    "            label: optional class indices; only used when ``do_cat`` is true.\n",
    "            batch_norm: if true, apply batch norm right after the convolution.\n",
    "        \"\"\"\n",
    "        y = self.conv1(x)\n",
    "        if batch_norm:\n",
    "            y = self.bn(y)\n",
    "        y = self.kwin1(self.mp1(y))\n",
    "        y = self.flatten(y)\n",
    "        if label is not None and self.do_cat:\n",
    "            # Build the one-hot rows on the same device as the activations instead\n",
    "            # of indexing a CPU matrix and forcing the result onto CUDA.\n",
    "            label = torch.as_tensor(label, device=y.device)\n",
    "            one_hot = torch.eye(self.n_classes + 1, device=y.device)[label]\n",
    "            y = self.d1(y, torch.sigmoid(self.cat(one_hot)))\n",
    "        else:\n",
    "            y = self.d1(y)\n",
    "        return F.log_softmax(self.linear(y), dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the toy network with 4 dendrites per neuron and move it to the GPU.\n",
    "# do_cat keeps its default (False), so cat_w_sparsity has no effect here.\n",
    "net = ToyNetwork(dpc=4, cat_w_sparsity=0.01).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam over all network parameters, without weight decay.\n",
    "opt = torch.optim.Adam(net.parameters(), lr=0.01, weight_decay=0.)\n",
    "# The network returns log_softmax outputs, hence the negative log-likelihood loss.\n",
    "criterion = F.nll_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.9635416666666666, 0.84375]\n",
      "[0.9765625, 0.7578125]\n"
     ]
    }
   ],
   "source": [
    "# Two sequential tasks of two classes each: classes [1, 2], then [3, 4].\n",
    "train_inds = np.arange(1, 5).reshape(2, 2)\n",
    "losses = []  # one entry per training batch, across both tasks\n",
    "for task_classes in train_inds:\n",
    "    experiment.combine_classes(task_classes)\n",
    "    # Output units of the classes that are NOT part of the current task. The\n",
    "    # previous call passed all of train_inds, which only ever selected unit 0.\n",
    "    frozen_units = clear_labels(task_classes)\n",
    "    # Put the network in training mode at the start of every task.\n",
    "    # NOTE(review): evaluate_model may leave it in eval mode - confirm.\n",
    "    net.train()\n",
    "    for x, y in experiment.train_loader:\n",
    "        opt.zero_grad()\n",
    "\n",
    "        out = net(x.cuda())  # no categorical projection\n",
    "        # With ToyNetwork(do_cat=True), use the categorical projection instead:\n",
    "        # out = net(x.cuda(), y.cuda())\n",
    "        loss = criterion(out, y.cuda())\n",
    "        loss.backward()\n",
    "        losses.append(loss.item())\n",
    "\n",
    "        # Called between backward() and step(), before the update is applied.\n",
    "        freeze_output_layer(net, frozen_units,\n",
    "                            layer_type=\"kwinner\", linear_number=\"\")\n",
    "\n",
    "        opt.step()\n",
    "    # Mean accuracy on the classes of the task that was just trained.\n",
    "    print([evaluate_model(net, experiment.test_loader[k], torch.device(\"cuda\"))[\"mean_accuracy\"]\n",
    "           for k in task_classes])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.0, 0.0, 0.9765625, 0.7578125]"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean accuracy on each class, in the order the classes were trained.\n",
    "[\n",
    "    evaluate_model(net, experiment.test_loader[class_idx], torch.device(\"cuda\"))[\"mean_accuracy\"]\n",
    "    for class_idx in train_inds.flatten()\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Some results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 303,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.2808988764044944, 0.04924242424242424, 0.684, 0.0]"
      ]
     },
     "execution_count": 303,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Adam, dpc=4, cat sparsity=0.01, output freezing, output_size=5\n",
    "[\n",
    "    evaluate_model(net, experiment.test_loader[class_idx], torch.device(\"cuda\"))[\"mean_accuracy\"]\n",
    "    for class_idx in train_inds.flatten()\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 308,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.0, 0.007575757575757576, 0.872, 0.31048387096774194]"
      ]
     },
     "execution_count": 308,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same\n",
    "[\n",
    "    evaluate_model(net, experiment.test_loader[class_idx], torch.device(\"cuda\"))[\"mean_accuracy\"]\n",
    "    for class_idx in train_inds.flatten()\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 295,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "7.5186314861691"
      ]
     },
     "execution_count": 295,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# log10 of the number of trainable parameters of net.\n",
    "np.log10(get_no_params(net))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 304,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7f4d1ef38f10>]"
      ]
     },
     "execution_count": 304,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAYS0lEQVR4nO3de3Sc9X3n8ffHdxsbX2WvwRfZxCQkXARRCFmSlAYSbk0hbZLF3UPZhNbJNuxJTrPn1CQ9DW2abZIt0EtasrC4QJtwaQiEPUCKAxQCCRcZG1uAwVdsyUKSrYsvulnSd/+YR3gsjy6j0ejy9PM6Z46e+T3PzHzm8aOPn3lmRo8iAjMzS5cJox3AzMyGn8vdzCyFXO5mZinkcjczSyGXu5lZCk0a7QAACxYsiNLS0tGOYWY2rmzYsGF/RJTkmjcmyr20tJSKiorRjmFmNq5IeruveT4sY2aWQi53M7MUcrmbmaWQy93MLIVc7mZmKeRyNzNLIZe7mVkKudzN+lBZ3czGPY2jHcNsSMbEl5jMxqLf+vvnANj93StHOYlZ/rznbmaWQi53M7MUcrmbmaWQy93MLIVc7mZmKeRyNzNLIZe7mVkKudzNzFLI5W5mlkIudzOzFHK5m5mlkMvdzCyFBix3Sesk1UmqzBq7X9Km5LJb0qZkvFRSa9a8HxYzvJmZ5TaYvwp5F/AD4J6egYj4Lz3Tkm4GmrOW3xERZcMV0MzM8jdguUfEs5JKc82TJODzwCeGN5aZmRWi0GPuHwNqI2Jb1tgKSRslPSPpY33dUNIaSRWSKurr6wuMYWZm2Qot99XAvVnXa4BlEXEu8MfAjyWdnOuGEXF7RJRHRHlJSUmBMczMLNuQy13SJOB3gPt7xiKiPSIOJNMbgB3A6YWGNDOz/BSy534JsDUiqnoGJJVImphMrwRWATsLi2hmZvkazEch7wV+DbxXUpWk65NZ13D8IRmAjwObJb0K/AT4ckQ0DGdgMzMb2GA+LbO6j/H/lmPsQeDBwmOZmVkh/A1VM7MUcrmbmaWQy93MLIVc7mZmKeRyNzNLIZe7mVkKudzNzFLI5W5mlkIudzOzFHK5m5mlkMvdzCyFXO5mZinkcjczSyGXu5lZCrnczcxSyOVuZpZCLnczsxQazGn21kmqk1SZNXaTpGpJm5LLFVnzbpS0XdKbki4tVnAzM+vbYPbc7wIuyzF+a0SUJZfHACS9n8y5VT+Q3OYfe06YbWZmI2fAco+IZ4HBnuT6KuC+iGiPiF3AduD8AvKZmdkQFHLM/QZJm5PDNnOTsVOBvVnLVCVjJ5C0RlKFpIr6+voCYpiZWW9DLffbgNOAMqAGuDkZV45lI9cdRMTtEVEeEeUlJSVDjGFmZrkMqdwjojYiuiKiG7iDY4deqoClWYsuAfYVFtHMzPI1pHKXtDjr6meAnk/SPAJcI2mqpBXAKuClwiKamVm+Jg20gKR7gYuABZKqgG8BF0kqI3PIZTfwJYCIeE3SA8DrQCfwlYjoKk50MzPry4DlHhGrcwzf2c/y3wG+U0goMzMrjL+hamaWQi53M7MUcrmbmaWQy93MLIVc7mZmKeRyNzNLIZe7mVkKudzNzFLI5W5mlkIudzOzFHK5m5mlkMvdzCyFXO5mZinkcjczSyGXu5lZCrnczcxSyOVuZpZCA5a7pHWS6iRVZo39b0lbJW2W9JCkOcl4qaRWSZuSyw+LGd7MzHIbzJ77XcBlvcbWA2dGxNnAW8CNWfN2RERZcvny8MQ0M7N8DFjuEfEs0NBr7ImI6EyuvgAsKUI2MzMbouE45v5F4PGs6yskbZT0jKSP9XUjSWskVUiqqK+vH4YYZmbWo6Byl/RNoBP4UTJUAyyLiHOBPwZ+LOnkXLeNiNsjojwiyktKSgqJYWZmvQy53CVdB/wW8F8jIgAioj0iDiTTG4AdwOnDEdTMzAZvSOUu6TLgT4Df
joiWrPESSROT6ZXAKmDncAQ1M7PBmzTQApLuBS4CFkiqAr5F5tMxU4H1kgBeSD4Z83HgLyR1Al3AlyOiIecdm5lZ0QxY7hGxOsfwnX0s+yDwYKGhzMysMP6GqplZCrnczcxSyOVuZpZCLnczsxRyuZuZpZDL3cwshVzuZmYp5HI3M0shl7uZWQq53M3MUsjlbmaWQi53M7MUcrmbmaWQy93MLIVc7mZmKeRyNzNLIZe7mVkKDarcJa2TVCepMmtsnqT1krYlP+cm45L0d5K2S9os6bxihTczs9wGu+d+F3BZr7G1wJMRsQp4MrkOcDmZE2OvAtYAtxUe08zM8jGoco+IZ4HeJ7q+Crg7mb4buDpr/J7IeAGYI2nxcIQ1M7PBKeSY+6KIqAFIfi5Mxk8F9mYtV5WMHUfSGkkVkirq6+sLiGFmZr0V4w1V5RiLEwYibo+I8ogoLykpKUIMM7P/uAop99qewy3Jz7pkvApYmrXcEmBfAY9jZmZ5KqTcHwGuS6avA36WNf77yadmLgCaew7fmJnZyJg0mIUk3QtcBCyQVAV8C/gu8ICk64E9wOeSxR8DrgC2Ay3AF4Y5s5mZDWBQ5R4Rq/uYdXGOZQP4SiGhzMysMP6GqplZCrnczcxSyOVuZpZCLnczsxRyuZuZpZDL3cwshVzuZmYp5HI3G8ChtqOjHcEsby53swFMnJDrb+GZjW0udzOzFHK5mw0gTviD1WZjn8vdzCyFXO5mA/COu41HLnezAYSPy9g45HI3M0shl7vZALzfbuORy91sAD4qY+PRoM7ElIuk9wL3Zw2tBP4MmAP8IVCfjH8jIh4bckIzM8vbkPfcI+LNiCiLiDLgg2TOl/pQMvvWnnkudhvv9hxoGe0IZnkbrsMyFwM7IuLtYbo/szHj0z94brQjmOVtuMr9GuDerOs3SNosaZ2kubluIGmNpApJFfX19bkWMTOzISq43CVNAX4b+Ndk6DbgNKAMqAFuznW7iLg9IsojorykpKTQGGZmlmU49twvB16JiFqAiKiNiK6I6AbuAM4fhscwM7M8DEe5rybrkIykxVnzPgNUDsNjmJlZHob8UUgASTOATwJfyhr+vqQyMt/92N1rnpmZjYCCyj0iWoD5vcauLSiRmZkVzN9QNTNLIZe7mVkKudzNzFLI5W5mlkIudzOzFHK5m5mlkMvdzCyFXO5mZinkcjczSyGXu5lZCrnczcxSyOVuZpZCLnczsxRyuZuZpZDL3cwshVzuZmYp5HI3M0uhgs7EBCBpN3AI6AI6I6Jc0jzgfqCUzKn2Ph8RjYU+lpmZDc5w7bn/ZkSURUR5cn0t8GRErAKeTK6bmdkIKdZhmauAu5Ppu4Gri/Q4ZmaWw3CUewBPSNogaU0ytigiagCSnwt730jSGkkVkirq6+uHIYaZmfUo+Jg7cGFE7JO0EFgvaetgbhQRtwO3A5SXl8cw5DAzs0TBe+4RsS/5WQc8BJwP1EpaDJD8rCv0cczMbPAKKndJJ0ma1TMNfAqoBB4BrksWuw74WSGPYzba9h9uH+0IZnkp9LDMIuAhST339eOI+Lmkl4EHJF0P7AE+V+DjmI2qppYOFsycOtoxzAatoHKPiJ3AOTnGDwAXF3LfZmY2dP6GqplZCrnczcxSyOVuZpZCLnczsxRyuZsNikY7gFleXO5mZinkcjczSyGXu5lZCrnczcxSyOVuZpZCLnczsxRyuZsNik85YOOLy91sEJpbO0c7glleXO5mg3DTI6+NdgSzvLjczQZhW92h0Y5glheXu9kgtB3tHu0IZnlxuZuZpdCQy13SUklPS3pD0muSvpqM3ySpWtKm5HLF8MU1M7PBKOQ0e53A1yPileQk2RskrU/m3RoRf114PDMzG4ohl3tE1AA1yfQhSW8Apw5XMDMzG7phOeYuqRQ4F3gxGbpB0mZJ6yTN7eM2ayRVSKqor68fjhhmZpYouNwlzQQeBL4WEQeB24DTgDIye/Y357pdRNwe
EeURUV5SUlJoDDMzy1JQuUuaTKbYfxQRPwWIiNqI6IqIbuAO4PzCY5qZWT4K+bSMgDuBNyLilqzxxVmLfQaoHHo8MzMbikI+LXMhcC2wRdKmZOwbwGpJZWT+0tJu4EsFJTQzs7wV8mmZ58h91uDHhh7HzMyGg7+hamaWQi53M7MUcrmbmaWQy93MLIVc7mZmKeRyNzNLIZe7mVkKudzNzFLI5W5mlkIudzOzFBr35d7ccpSOTp+82Mws27gv93P+4gn+4J6K0Y5hZjamjPtyB3j2LZ/JyYqvdO2j3Lr+rdGOYTYoqSh3s5Hyt09uG+0IZoPicjfLk9/jsfHA5W7Wj4++Z8EJYzvqD+d9Pw9vrKbuYNtwRDIblHFd7hHx7nRVY8soJrE0mjRBnLN09gnjl//tL4+7/rNN1Xz4f/2C0rWPcrDt6AnLt3R08rX7N/H5//PrIeVoPNJBd3fw6x0H6OqOgW8wRu0/3D7ij1l/qJ1LbnmGtw8cGfHHziUi+Np9G/mn53exve4wp//p4+wcws7CYBSt3CVdJulNSdslrS3GY1Q1tr47/dHvPX1c2adVe2cXbUe7Thjv7Ooe1PPvXQ61B9t4emsdzS0nllLP8t3JbTo6u7nr+V0c7crvsERnVzfNrcfuv6s7qGlu5Z3m3HuytQfb2NuQ33/WLR2dxz1GtrajXXltG+80t/GH91TQ2U+Rlq599N3LV+/bRO3BTHGdfdMTPLSxip9X1rC3oYV3mtv4yF89BcDuAy38ZEMVP698h7dqD1FZ3UxN87Ft+JU9jdQfOlaAD2+s5safbuHcb69n5TceY/UdL/DBv1z/7vyd9Yf51Y79Jzy3nfWH6Uz+jbZUNVPd1ErdobY+13e+ntpaS0tHJzXNrZSufZSXdjX0u/zRrm5K1z5K+V/+ggcq9tLa0dXvNtTa0cXGPY08tbWWt2oP9bncq3ubTrifA4fbj9t2rr3zRbbXHeb//nLXcctFRM7fo5551U2tvLavmTue3fnu+Is7D1C69lGqm479mzUc6eDuX+0+7lDdhrcbeW7bfgAe3FDFLVlvwr+0q4GHN+3jz//f61xyyzN0dHbziZuf6fM5FkLFKERJE4G3gE8CVcDLwOqIeD3X8uXl5VFRkf/HGbdUNfPpHzx33Nh7F81i2fwZCOiOoLHlKBvebjzhtpecsZCypXP46ycyK/67v3MWW6qb+dGLewCYMmkCHZ3dzDtpCg1HOvrMsGDmFPYfPjb/ixeuYN3zx29Iv3veEh58perd66sWzmRb3eEkxyLqDrWxuar5hPteOm86bUe7mTN9Muctm0t7ZxcPb9o3wFrJbdrkCbQdLe6x4hlTJtLSkfsXZjjMP2kKjS0djOTO6yVnLOKbV55BTXMrv3fHiyP3wKNo0gSd8B/b2Utm59xGh0vP78iCmVOLvod/ztI5vLq3adjv99oLlvPy7ga2vtP3f0h92f3dK4f0mJI2RER5znlFKvePADdFxKXJ9RsBIuKvci0/1HLv6Ozm9D99vJCoZv269oLlfPvqMwHY8HYDz28/cNyemNlwKEa5D/kE2QM4Fdibdb0K+HCvUGuANQDLli0b0oNMmTThuJUSEXR1B7WH2nnznYMIUd3UysG2o8yaOomqxlZOnj6Zw+2dzJw6iUUnT+Pnle/Q1d3N7314OS0dnWza28Qzb9bzodJ57Kg/zL6mVhbNnsaFpy2gsaWDnfVHqG5qZU9DC1+8cAXTp0zgp69Uc6itk6vKTuEjp83n4Y3VdHQFZUtms7m6mSVzp/PK203MmjaJSRNFa0cXM6ZM4rnt+/nE+xZy1qmzaevsYv+hDiYINuxppP1oN9VNrcycOonPly+lpaOT/YfbWT7/JE6eNpnpUyZQ09xGc8tRTp07naVzZzBnxmS+/ejrXHHmYhaePI26g22UzJpKa0cXexpauPiMRTS1dPDirgaWzJ1OdwTly+excW8TpfNn8Pz2A6wsOYlfbqvntX0HOdTWydc/
eTpvN7RQ1dhCe2c3F6ycT01TK8/vOMCFp83ncHsnJbOm8XhlDSUzp7JiwUlMnzKRX+04wJzpk1k+fwYzpkyisaWDTXub+M+nzWfZvBlccdZintpax3sWzqTxSAf3vrSXs5fM5p2DbUycIM48ZTZzZkzmSHsXM6dNYkf9YTo6u5k8Uby8u5HfOL2E57dnXvpOnCBK55/EmafO5s7ndvKFC1cwcYLY09DCE6/V8r3fPYtvPLSF0xfN4qVdDfzRRe/htIUnsf71WlYumMny+TPY+s4httcdYvqUSbS0d3LByvlIcP1HV7y7fX1w+Tw+uHweV569mK7uYM+BFlYtmsm9L+3l1zsP8O2rPsAv3qhj455Gfpm8LAd48uu/we79R9hS3czf/GIbq89fxv0v76E7oHT+DH7zfQuprG7m7CVzuPO5Y6/6PlQ6l6aWo2yrO8zVZaewc/+RfveeZ0+f3OehqTNPPZnWji521A/+2PPUSRNo7/XJoIveW8K/v1nY90quPGsxj26pAeBjqxYct64G0t9zzGXx7Gm8Z+HMEx5j2bwZ7Mlx6K+/V7ifPucUBDzy6vGvnq8uO+W4V9RzZ0zmSEcXkyeI0//TLDbuOfFVwqUfWER3ZA5Zff+zZw/6+eSjWHvunwMujYg/SK5fC5wfEf8j1/JD3XM3M/uPrL8992K9oVoFLM26vgQY2sFiMzPLW7HK/WVglaQVkqYA1wCPFOmxzMysl6Icc4+ITkk3AP8GTATWRcRrxXgsMzM7UbHeUCUiHgMeK9b9m5lZ38b1N1TNzCw3l7uZWQq53M3MUsjlbmaWQkX5ElPeIaR64O0C7mIBMPivuY0uZy2e8ZR3PGWF8ZV3PGWFwvIuj4iSXDPGRLkXSlJFX9/SGmuctXjGU97xlBXGV97xlBWKl9eHZczMUsjlbmaWQmkp99tHO0AenLV4xlPe8ZQVxlfe8ZQVipQ3FcfczczseGnZczczsywudzOzFBrX5T4SJ+HOl6TdkrZI2iSpIhmbJ2m9pG3Jz7nJuCT9XZJ/s6TzRiDfOkl1kiqzxvLOJ+m6ZPltkq4bwaw3SapO1u8mSVdkzbsxyfqmpEuzxkdkO5G0VNLTkt6Q9JqkrybjY2799pN1TK5fSdMkvSTp1STvnyfjKyS9mKyn+5M/MY6kqcn17cn80oGexwhkvUvSrqx1W5aMF2c7iIhxeSHzp4R3ACuBKcCrwPvHQK7dwIJeY98H1ibTa4HvJdNXAI8DAi4AXhyBfB8HzgMqh5oPmAfsTH7OTabnjlDWm4D/mWPZ9yfbwFRgRbJtTBzJ7QRYDJyXTM8ic5L494/F9dtP1jG5fpN1NDOZngy8mKyzB4BrkvEfAv89mf4j4IfJ9DXA/f09jxHKehfw2RzLF2U7GM977ucD2yNiZ0R0APcBV41ypr5cBdydTN8NXJ01fk9kvADMkbS4mEEi4lmgocB8lwLrI6IhIhqB9cBlI5S1L1cB90VEe0TsAraT2UZGbDuJiJqIeCWZPgS8QeZ8wmNu/faTtS+jun6TdXQ4uTo5uQTwCeAnyXjvdduzzn8CXCxJ/TyPkcjal6JsB+O53HOdhLu/jXOkBPCEpA3KnAQcYFFE1EDmlwpYmIyPleeQb77Rzn1D8vJ1Xc8hjn4yjUrW5DDAuWT22sb0+u2VFcbo+pU0UdImoI5M0e0AmiKiM8djv5srmd8MzB+pvL2zRkTPuv1Osm5vlTS1d9ZemQrKOp7LXTnGxsLnOi+MiPOAy4GvSPp4P8uO1efQo698o5n7NuA0oAyoAW5OxsdMVkkzgQeBr0XEwf4WzTE2oplzZB2z6zciuiKijMw5mc8HzujnsUc1b++sks4EbgTeB3yIzKGWPylm1vFc7mPyJNwRsS/5WQc8RGYjrO053JL8rEsWHyvPId98o5Y7ImqTX5xu4A6OvaQeE1klTSZTlj+KiJ8mw2Ny/ebKOtbXb5KxCfh3Msen50jqOaNc9mO/myuZP5vMIb4R
zZuV9bLkUFhERDvwTxR53Y7nch9zJ+GWdJKkWT3TwKeAyiRXzzvd1wE/S6YfAX4/ebf8AqC55+X7CMs3378Bn5I0N3nZ/qlkrOh6vSfxGTLrtyfrNcmnJFYAq4CXGMHtJDmmeyfwRkTckjVrzK3fvrKO1fUrqUTSnGR6OnAJmfcJngY+myzWe932rPPPAk9F5l3Kvp5HsbNuzfoPXmTeG8het8O/HQz1HeGxcCHzLvNbZI69fXMM5FlJ5p34V4HXejKROdb3JLAt+Tkvjr2r/g9J/i1A+QhkvJfMy+2jZPYMrh9KPuCLZN6M2g58YQSz/nOSZXPyS7E4a/lvJlnfBC4f6e0E+CiZl82bgU3J5YqxuH77yTom1y9wNrAxyVUJ/FnW79xLyXr6V2BqMj4tub49mb9yoOcxAlmfStZtJfAvHPtETVG2A//5ATOzFBrPh2XMzKwPLnczsxRyuZuZpZDL3cwshVzuZmYp5HI3M0shl7uZWQr9fzkD5pRb0+SQAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot the values collected in losses (appended once per training batch).\n",
    "plt.plot(losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 213,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.0, 0.4838709677419355, 0.28, 0.50187265917603]"
      ]
     },
     "execution_count": 213,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Adam, dpc=4, tanh cat fn, output freezing, output_size=5\n",
    "[\n",
    "    evaluate_model(net, experiment.test_loader[class_idx], torch.device(\"cuda\"))[\"mean_accuracy\"]\n",
    "    for class_idx in train_inds.flatten()\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 207,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.20973782771535582, 0.0, 0.468, 0.4166666666666667]"
      ]
     },
     "execution_count": 207,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Adam, dpc=4, tanh cat fn, no freezing, output_size=5\n",
    "[\n",
    "    evaluate_model(net, experiment.test_loader[class_idx], torch.device(\"cuda\"))[\"mean_accuracy\"]\n",
    "    for class_idx in train_inds.flatten()\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 154,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.0, 0.7651515151515151, 0.0, 0.14516129032258066]"
      ]
     },
     "execution_count": 154,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# dpc = 1, SGD lr=0.1, # no cat act_fn, no output freezing, output_size=11\n",
    "[\n",
    "    evaluate_model(net, experiment.test_loader[class_idx], torch.device(\"cuda\"))[\"mean_accuracy\"]\n",
    "    for class_idx in train_inds.flatten()\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 155,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7f4d1ea9b890>]"
      ]
     },
     "execution_count": 155,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAaoklEQVR4nO3deXhc9X3v8fd3ZiSNtdmyJW94kQ0GY4ixwSZsoew4hgRuCy3pDTfATelt0oSkN02chASSlsBtSJpe7r3NQxNCaAlJytLwlLAbSCCJjYwN2JY3bCG8S8ja15n53T/mSJaNJUu2juYcnc/refxoljPSR4fho9/8zmbOOUREJLhiuQ4gIiKDU1GLiAScilpEJOBU1CIiAaeiFhEJuIQf37S8vNxVVlb68a1FRMakNWvW1DvnKo70nC9FXVlZSVVVlR/fWkRkTDKzdwd6TlMfIiIBp6IWEQk4FbWISMCpqEVEAk5FLSIScCpqEZGAU1GLiASciloio6a+jde21ec6hsiw+XLAi0gQXXTvywDU3HNVboOIDJNG1CIiAaeiFhEJOBW1iEjAqahFRAJORS0iEnAqahGRgFNRi4gEnIpaRCTgVNQiIgE3pKI2sy+a2QYzW29mj5hZ0u9gIiKSddSiNrMTgM8DS5xzpwNx4Aa/g4mISNZQpz4SwDgzSwCFwG7/IomISH9HLWrn3C7gXqAW2AM0OeeeO3w5M7vVzKrMrKqurm7kk4qIRNRQpj7KgGuAOcB0oMjMPnn4cs65+51zS5xzSyoqKkY+qYhIRA1l6uMyYIdzrs451wM8DpznbywREek1lKKuBc4xs0IzM+BSoNrfWCIi0msoc9SrgEeBN4C3vdfc73MuERHxDOkKL865O4A7fM4iIiJHoCMTRUQCTkUtIhJwKmoRkYBTUYuIBJyKWkQk4FTUIiIBp6IWEQk4FbWISMCpqEVEAk5FLSIScCpqEZGAU1GLiAScilpEJOBU1CIiAaeiFhEJOBW1iEjAqahFRAJORS0iEnAqahGRgFNRi4gEnIpaRCTgVNQiIgGnohYRCTgVtYhIwKmoRUQCTkUtIhJwKmoRkYBTUYuIBJyKWkQk4FTUEjnOuVxHEBkWFbVEjnpawkZFLZGjnpawUVGLiASciloiR3PUEjYqaokc1bSEjYpaIkcDagkbFbVEjtOYWkJmSEVtZhPM7FEz22Rm1WZ2rt/BREQkKzHE5f4JeMY5d52Z5QOFPmYS8ZWmPiRsjlrUZlYKXAjcBOCc6wa6/Y0lIiK9hjL1MReoA35iZmvN7EdmVnT4QmZ2q5lVmVlVXV3diAcVGSkaUUvYDKWoE8CZwD875xYDbcCKwxdyzt3vnFvinFtSUVExwjFFRo42JkrYDKWodwI7nXOrvPuPki1uEREZBUctaufcXuA9MzvFe+hSYKOvqUR8pKkPCZuh7vXxOeBhb4+P7cDN/kUS8Zd6WsJmSEXtnFsHLPE5i8io0Lk+JGx0ZKJEjmpawkZFLSIScCpqiRzNfEjYqKglelTUEjIqaokcHfAiYaOilsjR1IeEjYpaRCTgVNQSORpQS9ioqCVydMCLhI2KWiJHNS1ho6KWyNGAWsJGRS0iEnAqaokc7UctYaOiluhRT0vIqKglctTTEjYqaokcbUyUsFFRi4gEnIpaIkcbEyVsVNQSOZr6kLBRUUvkqKclbFTUEjk614eEjYpaRCTgVNQSORpQS9ioqEVEAk5FLZGjEbWEjYpaIkf7UUvYqKhFRAJORS2Ro6kPCRsVtUSOelrCRkUtkaMDXiRsVNQSOappCRsVtUSGWa4TiBwbFbVEjmY+JGxU1BIZBwfUamoJFxW1RIZ5cx8aUUvYqKglctTTEjYqaokMbUuUsBpyUZtZ3MzWmtl/+hlIxG+a+pCwGc6I+jag2q8gIn7r3T1PJ2WSsBlSUZvZDOAq4Ef+xhHxj6GNiRJOQx1R
/wD4MpDxMYuIv3pH1CpqCZmjFrWZXQ3sd86tOcpyt5pZlZlV1dXVjVhAkZGijYkSVkMZUZ8PfNzMaoCfA5eY2b8dvpBz7n7n3BLn3JKKiooRjikycjRHLWFz1KJ2zn3VOTfDOVcJ3ACsdM590vdkIj7R1IeEjfajlsjQSZkkrBLDWdg59zLwsi9JRHymvT4krDSilsjQiFrCSkUtkaONiRI2KmqJjN4BtaY+JGxU1BIZfac5zXEOkeFSUUtkHBxRq6olXFTUIiIBp6KWyNF4WsJGRS3RoZMySUipqCUydHFbCSsVtUSGLm4rYaWilsjQkYkSVipqiRwNqCVsVNQSGToyUcJKRS2RcXCOWk0t4aKilsjoG1HnNIXI8KmoRUQCTkUtkaOZDwkbFbVERu/ueToftYSNiloipK+pRUJFRS2RYeppCSkVtUSGDkyUsFJRS+RoY6KEjYpaIkMbEyWsVNQSGYbOnifhpKKWyFFPS9ioqCUydJpTCSsVtUSOTsokYaOilsjQSZkkrFTUEhmmI14kpFTUEjnaPU/CJjBF7Zzjsz97g0fX7Mx1FBmjtDFRwiowRW1m/G5bPWtrD+Q6ioxx2pYoYROYogaoKClgf0tXrmPIGNU3Ra2ilpAJVFFPLkmqqMV36mkJm4AVdQH1KmrxycFDyFXVEi6BKuqK0gLqWrr0P5L4QhsTJawCVdSTS5J0pzM0tvfkOoqMYRoGSNgErKgLADRPLb7oOzJRTS0hc9SiNrOZZvaSmVWb2QYzu82vMBV9Rd3p14+QCOs7MlFjagmZxBCWSQH/0zn3hpmVAGvM7Hnn3MaRDtM7oq7TiFp8oBG1hNVRR9TOuT3OuTe82y1ANXCCH2EmlyYBTX2IT7QxUUJqWHPUZlYJLAZWHeG5W82sysyq6urqjilMcUGCkmSC3Y0dx/R6kaHQgFrCZshFbWbFwGPAF5xzzYc/75y73zm3xDm3pKKi4pgDzSwr5L2G9mN+vchAegfUGc19SMgMqajNLI9sST/snHvcz0CzJhZSq6IWHyRi2bd7OqOilnAZyl4fBvwYqHbOfd/vQNMnjGNPU6cOepERF49lx9QqagmboYyozwduBC4xs3Xev+V+BZo+IUl7d5qmDh30IiMrEc8WdUpFLSFz1N3znHOvMorby2dPKgJge30bZ87KH60fKxGgEbWEVaCOTASYW5Et6pr6thwnkbEmEdOIWsIpcEU9o2wcZlDzvjYoysjqG1GnMzlOIjI8gSvqgkScyklFbN77gT0ARY5L714fGlFL2ASuqAEWTC9l4x4VtYysmOaoJaSCWdTTSnmvoUN7fsiI0hy1hFUgi/q06aUAVGtULSNIe31IWAWyqBd4Rb1ht4paRo7X0xpRS+gEsqgnlySpKClgo4paRlDvNRPTGe31IeESyKKG7Dz1ht1NuY4hY5BG1BI2gS3qM2ZOYMu+Fg60dec6iowx6bSKWsIlsEV9yfzJZBy8suXYzm0tMhCNqCVsAlvUC08YT3lxPis37c91FBljtNeHhE1gizoWMy4+ZTIvb95PSof8yghw3rVdevR+kpAJbFEDXHrqZJo7U1S9eyDXUWQM6UqpqCVcAl3UF8yrYFxenMff2JnrKDKGdKXSuY4gMiyBLuriggTXLj6BX1btpEF7f8gI6ezRiFrCJdBFDbDs9KkAfOy+V3OcRMYKjaglbAJf1BfOKwdgV2MHrV2pHKeRsaBLI2oJmcAXtZnxy788F4DT73hWoyE5bp16D0nIBL6oAZbMLuu7fcrtz1DX0pXDNBJ263fpHDISLqEo6ljMmD+1pO/+0rteyGEaGQv0x17CJBRFDfDU5z9yyP3/s3JrjpJIWLl+ByQuvesFurU/tYREaIo6HjOqv72Mry2fD8C9z22hcsVTdPZovlGGbvakwr7bJ9/+NJUrnuKP/99rPLF2p04AJoFlzo38eQ+WLFniqqqqRvz79qpc8dQh92vuucq3nyVjx60PVVHb0M51Z83g
75+qHtJr/mzJTL505SlUlBT4nE6izszWOOeWHOm50Iyo+3vzjisOuV+54in2NHXkKI2Ezac/MpfffvniIS37i6r3WHrXCzz55m6fU4kMLJHrAMdi/Lg8tn9nOX/+oz/wh+0NAJx790oA/u6a07jx3MocppMwmDmx8JBPYq1dKfY3d/LaO+9TvaeZn62qPWT5zz+ylvKifM47qXy0o4qEc+qjv1e31vPJH6865LFZEwt54W/+iI6eNOPH5Y1KDgm+3qmPZ75w4ZCW7+xJM/8bzxzy2E9uWsrF8yf7EU8ibsxNffR3wbxyHvurcw95rLahnZNvf5ozvvUcrV0pHvp9DbXvt+cmoIRWMi9OzT1X8T/+6MS+x25+8HW9l2TUhb6oAc6aPZEddy/nkiOMdE6/41m++asNXPjdl3KQTMaCFR+dz+a/X9Z3/8LvvsQ7da05TCRRMyaKGrKHmj9w01J+9hcfHnCZS+59mcXffo7KFU+xflcTe5s68WPqR8aegkScdd+8vO/+pd97JYdpJGrGTFH3Ou/EcmruuYpXv/LBrfrb69s40N4DwNX3vco5d7/ILQ++rn2xI+J4/yRPKMznm1cv6Lu/v6XzOL+jyNCMuaLuNaOskKrbL+O2S+cNutxLm+uY/41nqFzxFOffs7LvMk3dqYxG22OQmR3X62+5YE7f7bPvevF444gMyZgtaoDy4gK+ePnJrP7apUNafldjB4u+9RwrN+3j5Nuf5qHfv+tzQgmj6m8fnK++/oe/y2ESiYoxXdS9JpcmqbnnKt668wpuPGc2D3964Hnstu40tzyY3bXwjic38PCqd9m0t7lv41FPOsOepg6aOno+8Np0xvHCxn0aiY9x4/LjPPGZ8wB4vebAB46UFRlpkSjqXqXJPP7u2tM5/6Ry/uFPFvKReeWUJAc/5ufrT6xn2Q9+y6Xfe4Ut+1qY9/WnOffulXz4O9kz+O1p6uC5DXsBeODVHXz6oSqe9e7L2LV4VhnXnzWj777KWvwUyiMTR8KfLp3Jny6d2Xe/pr6Ni+59edDXXPGPv+m73dmT4eafrOalzXUATC1Nsrc5u3Fpv3cKzV+t28WZs8qYObGQzp40rV0pyot1zoix4rvXn4EDHl2Tvfhy5YqnuPNjC7jp/DmDv1BkmEJ/ZKIfaurbuG/lNh4boauf3/eJxXzukbUAXLtoOnf/8UJqG9o5oWwcr22rZ+GM8dy3chs3n1fJHU9u4KqF0/ivH559yPfY1dhBUX6cCYX5ADR39pAXi/Hcxr18/IzpvN/WzYG2buZNKdEfhQH8xUNV7DzQwdO3feToCw/Dsh/8hk17Ww557OJTKvjhjWdRkIiP6M+SsWuwIxOHVNRmtgz4JyAO/Mg5d89gy4e9qPtr7uxh1fYG8uLGd35dzZZ9uT/Q4W+vPIXvPru57355cQHvt3XhHLzytxfxtSfe5rVt7/P4Z85j4+5mrjtrBr9at4tnN+xjSmmSv7n8ZMoK82ho62ZvcyezJhYyoTCfh35fw5TSJFeelr2gcHt3ipgZ79S1MqU0ecTib2zvZldjB6dNH+/b79vRnSY/ESMeG/4eG92pDBt2N7F4VplvRQ3wTl3roPtWn3/SJC6cV8EtF8yhpTNFUUGcVNpRVBDZD7VymOMqajOLA1uAy4GdwOvAJ5xzGwd6zVgq6iP5xeu1zCgrZEppksu+rwMfAPLjMbrTB0/Ef9XCadTUt1FT30Zb96H7qX/xspNJxI1djR19Jz/6xtULWPNuA2++18TVZ0xj9Y4G1tY2ctmpU3ihel/fa69dNJ3/WJc9k91nLz6Rx9bsYt6UYhrbe3h7VxNTS5NccdoUahvauXzBFO57cVvflBTAnPIiXvrSRb6th8b2bm77+Tpe2VI3ot/3vBMnkRePcc7cSfx2ax1NHT2k0o4vXn4yXak0p0wt4Q/vvM/KzXUsmjGeyxdMpSSZ4JHXa5k9sYiedIYF00spLkiQzIuzo76VZCLOolkTiMeM
5zfuY2nlRArz46x59wALZ0xgX3MnrV0pTqoopieT6TtvTkNbN4lYjIxzrK1tZMG0UqaOT/JG7QFmTiykoriAls4eJhUX0NzZQzIRJy9udKczfZ8wUukM7T1pSgoSA+4yub+lk4riAjIOat5vo3JSEe3dKVq7UsRjRkVxAbUN7cyeVIRzDjMjk3GYZXfD3Li7me50hkUzJxzyfTt70iTzsjl60hliZh8YBKTSGeIxw8zoTmVIZTIYxrj87OuaOnqIx4xEzGjrSjGpuADnHHWtXUwuSR7Tf+PjLepzgTudc1d6978K4Jy7e6DXjPWi7i+TcaQyjvxEdrtse3eK/HiM9p40Te09pDOOX6/fQ11LF+fMncRf/uuaHCeW0Tp/+fa6Vho7enh+4z7+5TfbSWUcRfnxD/zhigKz7BV2er8ertTbqN+TdsQMYma0dKWG9TOKCxK0eq9J5sXo9K42X5gfp6ggQXtXirRzdPZkKC8uoL714OXYSpMJir0/GumM6/vjPm18kj1NB//QV5QUkB+Psavxg6dVnlCYR3FBgt9++eJj2l9/sKIeyueuE4D3+t3fCXxg/zYzuxW4FWDWrFnDDhlWsZiR3++vcWF+dpWWxmOUJrMjkM9cdFLf80crid6RwUDPOQfd6QxdqQwFiVjfyKC5s4eYGW/tbCQvHmNG2Tje3tlEQ1s3q3c0cNP5lUwtTfK957aQyjiKC+Ik4jG27m+lrSvFh04Yz9yKIlbvaOCVLXWUFeYzLi9OUUGcwvwEZtCVyrB6R3Ya6LJTp/D0+r3EY8a08UnqWrro8i5tddr0Ulo6U9Q2tDNrYiG1De2UFeb1HRV6JItnTWBtbeNR13dBItb3c47FvdefccyvHa65FcUAnDmrjK8smz/gcpmMY19LJy2d2ZKp3tPMtPHjaOroIREztte30daVYuPuZk6dVkpzZw9NHT1s3d/KZfMnU723mQXTSvn+81vIeCW4/ENT+dAJE/jp72rY29xJeXE+U0qTpDOOEycXs21fK7MnFVKYHydmxuNrd1FenM+siYXsb+mirDCfXY0dlBfnk844xo/Lo7K8iI7uNOt3N1GUn6A0mcfa9w7Qkz60eS9fMIXnN+7jE2fP5JHV71GSTHDBSeW8urWeq8+YTmN7N6/XNJBx2U84b+9s4qqF08mPG6mMozuVoSAvxovV+zl37iQaO3pYuWk/p00vZWJRPpv2tnCgrZuPL5rO42/sYvmHptLalWZG2Thq32/n3YY2Tp8+nj1NnWza28zHFk4nFuu9Tqbx2rZ6zjtxEvuaO1nlvZ+nlCY5Y+YEnIN4LHsB5Ia2bs6cXcaL1fuImdHenebsyokk8+J9268mFuXT0NbN3PIiFs8qY2llGamMIy9+fAdWHW4oI+rrgSudc5/27t8InO2c+9xAr4nSiFpEZCQc72lOdwIz+92fAehyFyIio2QoRf06MM/M5phZPnAD8KS/sUREpNdR56idcykz+2vgWbK75z3gnNvgezIREQGGeGSic+7XwK99ziIiIkcQqXN9iIiEkYpaRCTgVNQiIgGnohYRCThfzp5nZnXAsV4epRyoH8E4fgpTVghX3jBlhXDlDVNWCFfe48k62zlXcaQnfCnq42FmVQMdnRM0YcoK4cobpqwQrrxhygrhyutXVk19iIgEnIpaRCTggljU9+c6wDCEKSuEK2+YskK48oYpK4Qrry9ZAzdHLSIihwriiFpERPpRUYuIBFxgitrMlpnZZjPbZmYrcp2nl5nVmNnbZrbOzKq8xyaa2fNmttX7WuY9bmb2v73f4S0zO9PnbA+Y2X4zW9/vsWFnM7NPectvNbNPjXLeO81sl7d+15nZ8n7PfdXLu9nMruz3uO/vFTObaWYvmVm1mW0ws9u8xwO3fgfJGtR1mzSz1Wb2ppf3W97jc8xslbeefuGdVhkzK/Dub/Oerzza7zEKWR80sx391u0i73F/3gfZyzvl9h/Z06e+A8wF8oE3gQW5zuVl
qwHKD3vsH4AV3u0VwP/ybi8HngYMOAdY5XO2C4EzgfXHmg2YCGz3vpZ5t8tGMe+dwJeOsOwC731QAMzx3h/x0XqvANOAM73bJWQv8LwgiOt3kKxBXbcGFHu384BV3jr7JXCD9/gPgb/ybn8G+KF3+wbgF4P9HqOU9UHguiMs78v7ICgj6rOBbc657c65buDnwDU5zjSYa4Cferd/Clzb7/GHXNYfgAlmNs2vEM653wANx5ntSuB551yDc+4A8DywbBTzDuQa4OfOuS7n3A5gG9n3yai8V5xze5xzb3i3W4BqstcPDdz6HSTrQHK9bp1zrtW7m+f9c8AlwKPe44ev2951/ihwqZnZIL/HaGQdiC/vg6AU9ZEuoDvYG200OeA5M1tj2Qv4Akxxzu2B7P8kwGTv8SD8HsPNFoTMf+19THygdyphkFyjntf7qL2Y7Ggq0Ov3sKwQ0HVrZnEzWwfsJ1ta7wCNzrneS4/3/9l9ubznm4BJo5X38KzOud51e5e3bv/RzAoOz3pYpuPKGpSiPtIle4Oy3+D5zrkzgY8CnzWzCwdZNsi/x0DZcp35n4ETgUXAHuB73uOByGtmxcBjwBecc82DLXqEx0Y17xGyBnbdOufSzrlFZK/BejZw6iA/O6d5D89qZqcDXwXmA0vJTmd8xc+sQSnqwF5A1zm32/u6H3iC7JtqX++Uhvd1v7d4EH6P4WbLaWbn3D7vf4QM8C8c/Oia87xmlke2+B52zj3uPRzI9XukrEFet72cc43Ay2TncyeYWe9Vp/r/7L5c3vPjyU6hjWreflmXedNNzjnXBfwEn9dtUIo6kBfQNbMiMyvpvQ1cAawnm613q+2ngF95t58E/pu35fccoKn3Y/IoGm62Z4ErzKzM+2h8hffYqDhsDv+/kF2/vXlv8Lb4zwHmAasZpfeKNwf6Y6DaOff9fk8Fbv0OlDXA67bCzCZ4t8cBl5GdV38JuM5b7PB127vOrwNWuuwWuoF+D7+zbur3x9rIzqX3X7cj/z441q2hI/2P7NbSLWTnqr6e6zxeprlktyq/CWzozUV2fuxFYKv3daI7uIX4/3q/w9vAEp/zPUL2I20P2b/Y//1YsgG3kN0Qsw24eZTz/quX5y3vTT6t3/Jf9/JuBj46mu8V4AKyH03fAtZ5/5YHcf0OkjWo63YhsNbLtR74Zr//31Z76+nfgQLv8aR3f5v3/Nyj/R6jkHWlt27XA//GwT1DfHkf6BByEZGAC8rUh4iIDEBFLSIScCpqEZGAU1GLiAScilpEJOBU1CIiAaeiFhEJuP8PABY47j8ApskAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot the values collected in losses (appended once per training batch).\n",
    "plt.plot(losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 127,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0, 0.616, 0.0, 0.38257575757575757]"
      ]
     },
     "execution_count": 127,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean accuracy on the test loaders with indices 0 to 3.\n",
    "[\n",
    "    evaluate_model(net, experiment.test_loader[loader_idx], torch.device(\"cuda\"))[\"mean_accuracy\"]\n",
    "    for loader_idx in range(4)\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LinearToyNetwork(nn.Module):\n",
    "    \"\"\"Fully connected toy network that uses dendrite layers from the input on.\n",
    "\n",
    "    forward() applies: dendrite layer ``d0``, optional batch norm, k-winners,\n",
    "    dendrite layer ``d1`` (optionally given a projection of the one-hot label),\n",
    "    k-winners, dendrite layer ``d2``, k-winners and a log-softmax over\n",
    "    ``n_classes + 1`` outputs.\n",
    "\n",
    "    Note you will have to adjust the training loop a bit\n",
    "    to accommodate the image flattening (``d0`` takes 32*32 input features).\n",
    "\n",
    "    Args:\n",
    "        dpc: dendrites per neuron, used by all three dendrite layers.\n",
    "        linear1_w_sparsity: second argument of the ``SparseWeights`` wrapping ``linear1``.\n",
    "        linear2_w_sparsity: currently unused.\n",
    "        cat_w_sparsity: second argument of the ``SparseWeights`` wrapping ``cat``.\n",
    "        n_classes: the network has ``n_classes + 1`` outputs.\n",
    "        batch_norm: whether forward() applies batch norm after ``d0``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dpc=3,\n",
    "                 linear1_w_sparsity=0.5,\n",
    "                 linear2_w_sparsity=0.1,\n",
    "                 cat_w_sparsity=0.01,\n",
    "                 n_classes=4,\n",
    "                 batch_norm=True):\n",
    "        # This used to be super(ToyNetwork, self), which names the wrong class\n",
    "        # and fails as soon as the network is constructed.\n",
    "        super().__init__()\n",
    "        linear1_channels = 1000\n",
    "        linear2_channels = 512\n",
    "        self.batch_norm = batch_norm\n",
    "\n",
    "        self.n_classes = n_classes\n",
    "        # linear1 is not used by forward(); it is kept as the non-dendritic\n",
    "        # alternative to d0.\n",
    "        self.linear1 = SparseWeights(nn.Linear(32 * 32, linear1_channels), linear1_w_sparsity)\n",
    "        self.kwin1 = KWinners(linear1_channels, percent_on=0.1)\n",
    "        self.bn = nn.BatchNorm1d(linear1_channels, affine=False)\n",
    "\n",
    "        self.d0 = DendriteLayer(in_dim=32 * 32,\n",
    "                                out_dim=linear1_channels,\n",
    "                                dendrites_per_neuron=dpc,\n",
    "                                act_fun_type=\"kwinner\")\n",
    "        self.d1 = DendriteLayer(in_dim=linear1_channels,\n",
    "                                out_dim=linear2_channels,\n",
    "                                dendrites_per_neuron=dpc,\n",
    "                                act_fun_type=\"kwinner\")\n",
    "        self.d2 = DendriteLayer(in_dim=linear2_channels,\n",
    "                                out_dim=n_classes + 1,\n",
    "                                dendrites_per_neuron=dpc,\n",
    "                                act_fun_type=\"None\")\n",
    "        # NOTE(review): sized as d1's out_dim * dpc, mirroring ToyNetwork where the\n",
    "        # projection width is the dendrite layer's out_dim * dpc. It used to be\n",
    "        # linear1_channels * dpc - confirm against DendriteLayer.\n",
    "        self.cat = SparseWeights(nn.Linear(n_classes + 1, linear2_channels * dpc), cat_w_sparsity)\n",
    "\n",
    "        self.kwind1 = KWinners(linear2_channels, percent_on=0.1)\n",
    "        self.kwind2 = KWinners(n_classes + 1, percent_on=0.1)\n",
    "\n",
    "    def forward(self, x, label=None, batch_norm=False):\n",
    "        \"\"\"Return log-probabilities for the batch ``x``.\n",
    "\n",
    "        ``label``, when given, is one-hot encoded and projected through ``cat``;\n",
    "        the sigmoid of that projection is passed to ``d1`` as second argument.\n",
    "        The ``batch_norm`` argument is ignored: the constructor flag decides.\n",
    "        \"\"\"\n",
    "        # Alternative first layer without dendrites: y = self.linear1(x)\n",
    "        y = self.d0(x)\n",
    "        if self.batch_norm:\n",
    "            y = self.bn(y)\n",
    "        y = self.kwin1(y)\n",
    "        if label is not None:\n",
    "            # Build the one-hot rows on the same device as the activations instead\n",
    "            # of indexing a CPU matrix and forcing the result onto CUDA.\n",
    "            label = torch.as_tensor(label, device=y.device)\n",
    "            one_hot = torch.eye(self.n_classes + 1, device=y.device)[label]\n",
    "            y = self.d1(y, torch.sigmoid(self.cat(one_hot)))\n",
    "        else:\n",
    "            y = self.d1(y)\n",
    "\n",
    "        y = self.kwind1(y)\n",
    "        y = self.kwind2(self.d2(y))\n",
    "        return F.log_softmax(y, dim=1)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
